# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network blocks definition"""

import mindspore
from mindspore import nn, ops

init_mode = 'xavier_uniform'
has_bias = True
bias_init = 'zeros'

class REBNCONV(nn.Cell):
    """
    A basic unit consisting of convolution, batchnorm, and relu activation functions
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        """definition method"""
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, dilation=1 * dirate, weight_init=init_mode, has_bias=has_bias,
                                 bias_init=bias_init)
        self.bn_s1 = nn.BatchNorm2d(out_ch, affine=True)
        self.relu_s1 = nn.ReLU()

    def construct(self, x):
        """compute method"""
        hx = x
        hx = self.conv_s1(hx)
        hx = self.bn_s1(hx)
        xout = self.relu_s1(hx)
        return xout


def _upsample_like(src, tar):
    """generate upsample unit"""
    resize_bilinear = mindspore.ops.operations.ResizeBilinear(tar.shape[2:])
    src = resize_bilinear(src)
    return src


class RSU7(nn.Cell):
    """RSU7 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU7 definition"""
        super(RSU7, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU7 compute"""
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        hx6d = self.rebnconv6d(self.cat((hx7, hx6)))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(self.cat((hx6dup, hx5)))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(self.cat((hx5dup, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))

        return hx1d + hxin


class RSU6(nn.Cell):
    """RSU6 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU6 definition"""
        super(RSU6, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU6 compute"""
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(self.cat((hx6, hx5)))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(self.cat((hx5dup, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))

        return hx1d + hxin


class RSU5(nn.Cell):
    """RSU5 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU5 definition"""
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU5 compute"""
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(self.cat((hx5, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))

        return hx1d + hxin


class RSU4(nn.Cell):
    """RSU4 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU4 definition"""
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU4 compute"""
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(self.cat((hx4, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))
        return hx1d + hxin


class RSU4F(nn.Cell):
    """RSU4F block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU4F definition"""
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU4F compute"""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(self.cat((hx4, hx3)))
        hx2d = self.rebnconv2d(self.cat((hx3d, hx2)))
        hx1d = self.rebnconv1d(self.cat((hx2d, hx1)))

        return hx1d + hxin


class U2NET(nn.Cell):
    """U-2-Net model"""

    def __init__(self, in_ch=3, out_ch=1):
        """U-2-Net definition"""
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side2 = nn.Conv2d(64, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side3 = nn.Conv2d(128, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side4 = nn.Conv2d(256, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side5 = nn.Conv2d(512, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side6 = nn.Conv2d(512, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                                 bias_init=bias_init)

        self.cat = ops.Concat(axis=1)
        self.Sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, x):
        """U-2-Net compute"""
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # -------------------- decoder --------------------
        hx5d = self.stage5d(self.cat((hx6up, hx5)))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(self.cat((hx5dup, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(self.cat((hx2dup, hx1)))
        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(self.cat((d1, d2, d3, d4, d5, d6)))
        return self.cat((self.Sigmoid(d0), self.Sigmoid(d1), self.Sigmoid(d2), self.Sigmoid(d3), self.Sigmoid(d4),
                         self.Sigmoid(d5), self.Sigmoid(d6)))
